A skewed dataset is a dataset with a class imbalance. This leads to slow or failing Spark jobs that often hit an OOM (out of memory) error.

When performing a join on a skewed dataset, there is usually an imbalance in the key(s) the join is performed on. This results in the majority of the data falling onto a single partition, which will take longer to complete than the other partitions.
Some hints to detect skewness are:

- The key(s) consist mainly of `null` values, which fall onto a single partition.
- There is a subset of values for the key(s) that makes up a high percentage of the total keys, and those rows fall onto a single partition.

We go through both of these cases and see how we can combat them; a quick way to check for both patterns is sketched right after the Template below.
Library Imports
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
Template
spark = (
SparkSession.builder
.master("local")
.appName("Exploring Joins")
.config("spark.some.config.option", "some-value")
.getOrCreate()
)
sc = spark.sparkContext
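Before diving in, here is a minimal sketch of how the two hints above can be checked. The DataFrame `df` and the column name `key` are placeholders for whatever table and join column you are about to use; the toy data exists only to make the snippet runnable.

# Placeholder data: `df` stands in for the table you plan to join on "key".
df = spark.createDataFrame([
    (1,), (1,), (1,), (1,), (2,), (None,),
], ["key"])

# Hint 1: how many rows have a null join key?
df.filter(F.col("key").isNull()).count()

# Hint 2: does a small subset of key values make up most of the rows?
df.groupBy("key").count().orderBy(F.desc("count")).show()

If either check is dominated by a single value, the join on that key will concentrate most of the work in one task.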
Situation 2: High Frequency Keys
Initial Datasets
customers = spark.createDataFrame([
(1, "John"),
(2, "Bob"),
], ["customer_id", "first_name"])
customers.toPandas()
 | customer_id | first_name |
---|---|---|
0 | 1 | John |
1 | 2 | Bob |
orders = spark.createDataFrame([
(i, 1 if i < 95 else 2, "order #{}".format(i)) for i in range(100)
], ["id", "customer_id", "order_name"])
orders.toPandas().tail(6)
 | id | customer_id | order_name |
---|---|---|---|
94 | 94 | 1 | order #94 |
95 | 95 | 2 | order #95 |
96 | 96 | 2 | order #96 |
97 | 97 | 2 | order #97 |
98 | 98 | 2 | order #98 |
99 | 99 | 2 | order #99 |
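By construction, customer 1 owns 95 of the 100 orders and customer 2 owns the remaining 5. A quick check (not part of the original walkthrough) makes the imbalance explicit:

# 95 of the 100 orders belong to customer_id 1, only 5 to customer_id 2.
orders.groupBy("customer_id").count().orderBy(F.desc("count")).show()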
Option 1: Inner Join
df = customers.join(orders, "customer_id")
df.toPandas().tail(10)
 | customer_id | first_name | id | order_name |
---|---|---|---|---|
90 | 1 | John | 90 | order #90 |
91 | 1 | John | 91 | order #91 |
92 | 1 | John | 92 | order #92 |
93 | 1 | John | 93 | order #93 |
94 | 1 | John | 94 | order #94 |
95 | 2 | Bob | 95 | order #95 |
96 | 2 | Bob | 96 | order #96 |
97 | 2 | Bob | 97 | order #97 |
98 | 2 | Bob | 98 | order #98 |
99 | 2 | Bob | 99 | order #99 |
df.explain()
== Physical Plan ==
*(5) Project [customer_id#122L, first_name#123, id#126L, order_name#128]
+- *(5) SortMergeJoin [customer_id#122L], [customer_id#127L], Inner
:- *(2) Sort [customer_id#122L ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(customer_id#122L, 200)
: +- *(1) Filter isnotnull(customer_id#122L)
: +- Scan ExistingRDD[customer_id#122L,first_name#123]
+- *(4) Sort [customer_id#127L ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(customer_id#127L, 200)
+- *(3) Filter isnotnull(customer_id#127L)
+- Scan ExistingRDD[id#126L,customer_id#127L,order_name#128]
What Happened:
- We want to find what `order`s each `customer` made, so we will be `join`ing the `customer`s table to the `order`s table.
- When performing the join, we perform a `hashpartitioning` on `customer_id`.
- From our data creation, this means 95% of the data landed onto a single partition.
Results:
- Similar to the `Null Skew` case, this means that a single task/partition will take a lot longer than the others, and will most likely error out.
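One way to see this concretely (a rough sketch, not part of the original example) is to tag each row of the joined result with the partition it landed in and count rows per partition; a skewed key shows up as one partition holding the vast majority of the rows.

# Count how many rows of the joined output sit in each shuffle partition.
(
    df
    .withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .count()
    .orderBy(F.desc("count"))
    .show(5)
)

Running the same check after the salted join in Option 2 should show the rows spread over noticeably more partitions.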
Option 2: Salt the key, then Join
Helper Function
def data_skew_helper(left, right, key, number_of_partitions, how="inner"):
    # Assign each row on the left a random salt value in [0, number_of_partitions).
    salt_value = F.lit(F.rand() * number_of_partitions % number_of_partitions).cast('int')
    left = left.withColumn("salt", salt_value)

    # Duplicate each row on the right once per possible salt value.
    salt_col = F.explode(F.array([F.lit(i) for i in range(number_of_partitions)])).alias("salt")
    right = right.select("*", salt_col)

    # Join on the original key plus the salt, then drop the helper column.
    return left.join(right, [key, "salt"], how).drop("salt")
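The example below unpacks what this helper does step by step. Calling it directly on the tables above would look something like this, using the same key and number of salt values as the walkthrough:

# One-shot version of the walkthrough: salt both sides, join on (customer_id, salt), drop the salt.
df = data_skew_helper(customers, orders, "customer_id", 5)
df.orderBy("id").show(5)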
Example
num_of_partitions = 5
left = customers
salt_value = F.lit(F.rand() * num_of_partitions % num_of_partitions).cast('int')
left = left.withColumn("salt", salt_value)
left.toPandas().head(5)
 | customer_id | first_name | salt |
---|---|---|---|
0 | 1 | John | 4 |
1 | 2 | Bob | 3 |
right = orders
salt_col = F.explode(F.array([F.lit(i) for i in range(num_of_partitions)])).alias("salt")
right = right.select("*", salt_col)
right.toPandas().head(10)
 | id | customer_id | order_name | salt |
---|---|---|---|---|
0 | 0 | 1 | order #0 | 0 |
1 | 0 | 1 | order #0 | 1 |
2 | 0 | 1 | order #0 | 2 |
3 | 0 | 1 | order #0 | 3 |
4 | 0 | 1 | order #0 | 4 |
5 | 1 | 1 | order #1 | 0 |
6 | 1 | 1 | order #1 | 1 |
7 | 1 | 1 | order #1 | 2 |
8 | 1 | 1 | order #1 | 3 |
9 | 1 | 1 | order #1 | 4 |
df = left.join(right, ["customer_id", "salt"])
df.orderBy('id').toPandas().tail(10)
 | customer_id | salt | first_name | id | order_name |
---|---|---|---|---|---|
90 | 1 | 4 | John | 90 | order #90 |
91 | 1 | 4 | John | 91 | order #91 |
92 | 1 | 4 | John | 92 | order #92 |
93 | 1 | 4 | John | 93 | order #93 |
94 | 1 | 4 | John | 94 | order #94 |
95 | 2 | 3 | Bob | 95 | order #95 |
96 | 2 | 3 | Bob | 96 | order #96 |
97 | 2 | 3 | Bob | 97 | order #97 |
98 | 2 | 3 | Bob | 98 | order #98 |
99 | 2 | 3 | Bob | 99 | order #99 |
df.explain()
== Physical Plan ==
*(5) Project [customer_id#122L, salt#136, first_name#123, id#126L, order_name#128]
+- *(5) SortMergeJoin [customer_id#122L, salt#136], [customer_id#127L, salt#141], Inner
:- *(2) Sort [customer_id#122L ASC NULLS FIRST, salt#136 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(customer_id#122L, salt#136, 200)
: +- *(1) Filter (isnotnull(salt#136) && isnotnull(customer_id#122L))
: +- *(1) Project [customer_id#122L, first_name#123, cast(((rand(-8040129551223767613) * 5.0) % 5.0) as int) AS salt#136]
: +- Scan ExistingRDD[customer_id#122L,first_name#123]
+- *(4) Sort [customer_id#127L ASC NULLS FIRST, salt#141 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(customer_id#127L, salt#141, 200)
+- Generate explode([0,1,2,3,4]), [id#126L, customer_id#127L, order_name#128], false, [salt#141]
+- *(3) Filter isnotnull(customer_id#127L)
+- Scan ExistingRDD[id#126L,customer_id#127L,order_name#128]
What Happened:
- We created a new `salt` column for both datasets.
- On one of the datasets, we duplicate the data so that we have a row for each `salt` value.
- When performing the join, we perform a `hashpartitioning` on `[customer_id, salt]`.
Results:
- When we produce a row per `salt` value, we have essentially duplicated `(num_partitions - 1) * N` rows.
- This created more data, but allowed us to spread the data across more partitions, as you can see from `hashpartitioning(customer_id, salt)`.
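As a quick sanity check on that number (a sketch using the sizes from this example, where N = 100 orders and 5 salt values):

# The exploded right side has one row per (order, salt) pair.
original_rows = orders.count()             # 100
salted_rows = right.count()                # 100 * 5 = 500
extra_rows = salted_rows - original_rows   # (5 - 1) * 100 = 400
print(original_rows, salted_rows, extra_rows)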
Summary
All to say:
- By `salt`ing our keys, the `skewed` dataset gets divided into smaller partitions, thus removing the skew.
- Again, we will sacrifice more resources in order to get a performance gain or a successful run.
- We produced more data by creating `(num_partitions - 1) * N` more rows for the right side.